import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

class nist_layer(nn.Module):
    """Linear layer whose weight is restricted by a fixed, banded connectivity mask.

    Input unit ``i`` is connected to output units ``(i + k) % outdim`` for
    ``k in range(neuroseed_factor)``. Every other weight is multiplied by zero
    in ``forward``, so it neither contributes to the output nor receives gradient.

    Args:
        indim: Size of the last dimension of the input.
        outdim: Number of output features.
        neuroseed_factor: Number of consecutive (wrapping) outputs each input
            connects to. Must not exceed ``outdim``. A value <= 0 yields an
            all-zero mask.
        *args, **kwargs: Accepted and ignored (kept for call-site compatibility).

    Raises:
        ValueError: If ``neuroseed_factor > outdim``.
    """

    def __init__(self, indim, outdim, neuroseed_factor, *args, **kwargs):
        super(nist_layer, self).__init__()
        self.indim = indim
        self.outdim = outdim
        self.neuroseed_factor = neuroseed_factor
        if self.neuroseed_factor > self.outdim:
            raise ValueError("Neuroseed Factor Cannot Exceed Outdim")

        mask = torch.zeros([indim, outdim])
        if neuroseed_factor > 0 and indim > 0:
            # Vectorised form of: for i, for k: mask[i, (i + k) % outdim] = 1
            rows = torch.arange(indim).unsqueeze(1)                    # (indim, 1)
            cols = (rows + torch.arange(neuroseed_factor)) % outdim    # (indim, k)
            mask[rows, cols] = 1
        self.register_buffer("mask", mask)  # non-trainable; moves with .to()/.cuda()

        self.n_param = int(torch.count_nonzero(mask))
        self.n_param_dense = mask.numel()
        print("number of Weights(Sparse): {}, \nNumber of Weights(Dense): {}".format(self.n_param, self.n_param_dense))
        # NOTE: stored (indim, outdim), i.e. transposed relative to nn.Linear.
        self.weight = nn.Parameter(torch.empty(self.indim, self.outdim))
        self.bias = nn.Parameter(torch.empty(self.outdim))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias like ``nn.Linear``, then zero masked-out weights.

        ``nn.Linear`` uses ``kaiming_uniform_(a=sqrt(5))`` plus a bias bound, and
        both reduce to ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``. Because ``weight``
        is stored transposed, torch's fan-in helpers would read ``outdim``. The
        true fan-in here is ``indim``, so the bound is computed directly.
        """
        fan_in = self.indim
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        with torch.no_grad():
            nn.init.uniform_(self.weight, -bound, bound)
            self.weight.mul_(self.mask)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the masked linear map to the last dimension of ``x``.

        Args:
            x: Tensor whose last dimension has size ``indim``. Any leading
                dimensions are preserved.

        Returns:
            Tensor of shape ``x.shape[:-1] + (outdim,)``.
        """
        # Re-apply the mask each call so absent/pruned connections stay inert and
        # get zero gradient. Kept as an attribute because callers may inspect it.
        self.weight_core = self.weight * self.mask
        lead_shape = x.shape[:-1]
        # reshape (not view) so non-contiguous inputs, e.g. transposed ones, work.
        x_2d = x.reshape(-1, x.shape[-1])
        y_pred = torch.matmul(x_2d, self.weight_core) + self.bias
        return y_pred.reshape(*lead_shape, self.weight_core.shape[1])

    def apply_pruning_mask(self, new_mask):
        """Overwrite the connectivity mask in place and refresh ``n_param``.

        Args:
            new_mask: Tensor broadcastable to ``(indim, outdim)``. Nonzero entries
                keep a connection; zero entries disable it in ``forward``.
        """
        # no_grad keeps the buffer a plain non-trainable tensor even if new_mask
        # requires grad.
        with torch.no_grad():
            self.mask.copy_(new_mask)
        self.n_param = int(torch.count_nonzero(self.mask))
